import gym
import gym.error
import numpy as np
from gym.spaces import Box
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import scipy.stats

# Agent-id layouts for the ``env_config`` of MujocoMultiAgentEnv: which ids are standard
# (protagonist/antagonist) agents and which id, if any, is the adversary.
BASE_ROBUST_RL_PRESET = dict(
    standard_agent_ids=[0],
    adversary_agent_id=1,
)

ORACLE_ANTAGONIST_CONSTRAINED_ROBUST_RL_PRESET = dict(
    standard_agent_ids=[0],
    adversary_agent_id=1,
)

CONSTRAINED_ROBUST_RL_PRESET = dict(
    standard_agent_ids=[0, 2],
    adversary_agent_id=1,
)

CONSTRAINED_ROBUST_RL_MULTI_EPISODE_PRESET = dict(
    # protagonist is 0, 3, 5
    # antagonist is 2, 4, 6
    standard_agent_ids=[0, 2, 3, 4, 5, 6],
    adversary_agent_id=1,
)

# No adversary policy: MujocoMultiAgentEnv samples the adversary's actions at random instead.
DOMAIN_RANDOMIZATION_PRESET = dict(
    standard_agent_ids=[0],
    adversary_agent_id=None,
)

SINGLE_RUN_PAIRED_PRESET = dict(
    standard_agent_ids=[0, 2],
    adversary_agent_id=1,
)


class TimeLimit(gym.Wrapper):
    """Time-limit wrapper that also appends the normalized elapsed time to every observation.

    Like gym's ``TimeLimit`` it forces ``done=True`` once ``max_episode_steps`` steps have elapsed and sets
    ``info['TimeLimit.truncated']`` (True when the limit, not the env, ended the episode). In addition, each
    observation gets one trailing feature, ``elapsed_steps / max_episode_steps``, which lies in [0, 1].
    """

    def __init__(self, env, max_episode_steps=None):
        """Wrap ``env``.

        Args:
            env: Env to wrap. Its observation-space bounds are unpacked with ``*``, so it is presumably a
                1-D ``Box`` — confirm for new base envs.
            max_episode_steps: Step limit. When None, falls back to ``env.spec.max_episode_steps``.

        Raises:
            ValueError: if no positive step limit is available.
        """
        super(TimeLimit, self).__init__(env)
        if max_episode_steps is None and self.env.spec is not None:
            max_episode_steps = env.spec.max_episode_steps
        # The time feature divides by the limit, so a missing or non-positive limit would otherwise only
        # surface later as a TypeError / ZeroDivisionError on the first reset().
        if max_episode_steps is None or max_episode_steps <= 0:
            raise ValueError(f"TimeLimit requires a positive max_episode_steps, got {max_episode_steps}.")
        if self.env.spec is not None:
            self.env.spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = None

        # One extra dimension for the time feature, bounded to [0, 1].
        self.observation_space = Box(low=np.asarray([*self.env.observation_space.low, 0.0]),
                                     high=np.asarray([*self.env.observation_space.high, 1.0]),
                                     dtype=np.float32)

    def _add_timestep_to_obs(self, obs):
        """Return ``obs`` with the fraction of the step budget used so far appended."""
        return np.concatenate((obs, [self._elapsed_steps / self._max_episode_steps]))

    def step(self, action):
        """Step the wrapped env, enforcing the time limit and appending the time feature."""
        assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()"
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            # Truncated only if the env itself had not already terminated on this step.
            info['TimeLimit.truncated'] = not done
            done = True

        return self._add_timestep_to_obs(observation), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and the step counter; the returned time feature is 0."""
        self._elapsed_steps = 0
        return self._add_timestep_to_obs(self.env.reset(**kwargs))


class _JointAction(object):
    def __init__(self, pro=None, adv=None):
        self.pro = pro
        self.adv = adv


class MujocoMultiAgentEnv(MultiAgentEnv):
    """RLlib multi-agent wrapper around an adversarial MuJoCo gym env (default ``HopperAdv-v1``).

    Standard agents (``env_config["standard_agent_ids"]``) take turns: each plays one full base-env episode,
    and when it finishes the base env is reset for the next standard agent. The multi-agent episode ends once
    the last standard agent's base episode is done.

    The optional adversary (``env_config["adversary_agent_id"]``) works in one of two modes:

    * ``adversary_provides_distribution=False``: on every step it outputs a normalized force action alongside
      the acting standard agent, sees the same observation and receives the negated reward.
    * ``adversary_provides_distribution=True``: it acts once, at the start of the episode, emitting
      Beta-distribution parameters. The force applied on each base step is then sampled from those
      distributions. Its only reward is the negated episode return of agent 0.

    With ``adversary_agent_id=None`` the adversary's actions (or distribution parameters) are sampled at random
    from its action space instead (domain randomization).
    """

    def __init__(self, env_config=None):
        """Build the base env and the agent / action-space / observation-space bookkeeping.

        Args:
            env_config (dict): Must contain ``standard_agent_ids`` and ``adversary_agent_id`` (e.g.
                ``BASE_ROBUST_RL_PRESET``). Optional keys: ``base_gym_name``, ``max_episode_length``,
                ``adv_max_force``, ``adversary_provides_distribution``, ``x_axis_only``,
                ``x_axis_alpha_override`` and ``x_axis_beta_override``.

        Raises:
            ValueError: if no standard agents are configured, or only one of the two x-axis overrides is set.
            NotImplementedError: if ``adversary_provides_distribution`` is used with a non-1-D adversary
                action space.
        """
        self.base_gym_name = env_config.get("base_gym_name", "HopperAdv-v1")
        gym_spec = gym.spec(self.base_gym_name)

        max_episode_length = int(env_config.get("max_episode_length", 1000))
        # TimeLimit also appends the normalized elapsed time to every base observation.
        gym_env = TimeLimit(gym_spec.make(), max_episode_steps=max_episode_length)

        self._env = gym_env

        # Original RARL paper was 5.0
        self._env_max_force = env_config.get("adv_max_force", 20.0)
        self._env.set_adv_max_force(self._env_max_force)

        self._adversary_provides_distribution = bool(env_config.get("adversary_provides_distribution", False))
        self._x_axis_only = bool(env_config.get("x_axis_only", False))
        self._x_axis_alpha_override = env_config.get("x_axis_alpha_override", None)
        self._x_axis_beta_override = env_config.get("x_axis_beta_override", None)
        self._check_x_axis_overrides(alpha=self._x_axis_alpha_override, beta=self._x_axis_beta_override)

        # Running return of agent 0 this episode (distribution mode); negated, it is the adversary's reward.
        self._default_protag_reward_this_episode = None

        self.standard_agent_ids = env_config["standard_agent_ids"]
        self.adversary_agent_id = env_config["adversary_agent_id"]

        if self.adversary_agent_id is not None:
            # The adversary is listed first so that, in distribution mode, it acts before any standard agent.
            self.all_agent_ids = [self.adversary_agent_id, *self.standard_agent_ids]
        elif len(self.standard_agent_ids) > 0:
            self.all_agent_ids = self.standard_agent_ids
        else:
            raise ValueError("There are no agents in this env configuration.")

        if len(self.standard_agent_ids) == 0:
            raise ValueError("There are no standard agents in this env configuration.")

        # Index of the agent currently playing: into all_agent_ids in distribution mode, otherwise into
        # standard_agent_ids.
        self._acting_agent_idx = 0

        adv_base_shape = self._env.adv_action_space.shape
        if self._adversary_provides_distribution:
            if len(adv_base_shape) != 1:
                raise NotImplementedError("adversary_provides_distribution is only compatible with base action_spaces "
                                          "of length 1. "
                                          f"The action space {self._env.adv_action_space.shape} was provided.")
            if self._x_axis_only:
                # Read as a single (alpha, beta) pair for the x-axis force, which requires a 2-D base adversary
                # action (asserted when sampling).
                adversary_action_space = Box(low=-1.0, high=1.0, shape=(adv_base_shape[0],), dtype=np.float32)
            else:
                # Double the action space size to define beta distributions for each base action space dimension.
                adversary_action_space = Box(low=-1.0, high=1.0, shape=(adv_base_shape[0] * 2,), dtype=np.float32)
            self._adversary_distribution_params = None
        else:
            adversary_action_space = Box(low=-1.0, high=1.0, shape=adv_base_shape, dtype=np.float32)

        self.action_space = {
            "adversary": adversary_action_space,
            "standard_agent": Box(low=-1.0, high=1.0, shape=self._env.pro_action_space.shape, dtype=np.float32),
        }

        if self._adversary_provides_distribution:
            # The distribution-providing adversary acts once per episode from a constant dummy observation.
            adversary_obs_space = Box(low=0.0, high=0.0, shape=(1,), dtype=np.float32)
        else:
            adversary_obs_space = self._env.observation_space

        self.observation_space = {
            "adversary": adversary_obs_space,
            "standard_agent_partial": self._env.observation_space,
        }

        # NOTE(review): standard agents map to the *raw* pro_action_space here, while their actions are passed
        # through _denormalize_action, which expects [-1, 1] (self.action_space["standard_agent"]). The two
        # only coincide when the raw space is itself [-1, 1] — confirm before switching base envs.
        self.agent_ids_to_action_spaces = {agent_id: self._env.pro_action_space for agent_id in self.standard_agent_ids}
        # With adversary_agent_id=None this stores the space under the key None; the simultaneous-mode step
        # samples random adversary actions from it.
        self.agent_ids_to_action_spaces[self.adversary_agent_id] = adversary_action_space

    @staticmethod
    def _check_x_axis_overrides(alpha, beta):
        """Raise ValueError unless the x-axis Beta overrides are both set or both None."""
        if (alpha is None) != (beta is None):
            raise ValueError("x_axis_alpha_override and x_axis_beta_override must both be set or both be None, "
                             f"got alpha={alpha}, beta={beta}.")

    def _post_process_obs(self, observation: np.ndarray) -> np.ndarray:
        """Hook for observation post-processing; currently returns the observation unchanged."""
        # maybe need to normalize?
        return observation

    @staticmethod
    def _action_to_beta_params(action: np.ndarray) -> np.ndarray:
        """Map a normalized adversary action in [-1, 1] to Beta-distribution parameters in (0, 10].

        Exact zeros (from an action of -1) are replaced with 1e-8 because Beta parameters must be positive.
        """
        params = ((action + 1.0) / 2.0) * 10.0
        params[params == 0] = 1e-8
        assert (params > 0.0).all() and (params <= 10.0).all(), params
        return params

    def reset(self):
        """Resets the env and returns observations from ready agents.

        In distribution mode with an adversary, only the adversary is ready and it receives a zero observation.
        Otherwise the base env is reset for the first standard agent. A simultaneous adversary gets the same
        observation. In domain randomization with distribution mode, fresh random Beta parameters are drawn.

        Returns:
            obs (dict): New observations for each ready agent.
        """
        self._acting_agent_idx = 0

        if self._adversary_provides_distribution:
            self._default_protag_reward_this_episode = 0.0

        if self._adversary_provides_distribution and self.adversary_agent_id is not None:
            acting_agent_id = self.all_agent_ids[self._acting_agent_idx]
            assert acting_agent_id == self.adversary_agent_id
            observation = np.zeros(shape=self.observation_space["adversary"].shape, dtype=np.float32)
            return {acting_agent_id: observation}

        acting_agent_id = self.standard_agent_ids[self._acting_agent_idx]
        assert acting_agent_id == 0
        observation = self._post_process_obs(observation=self._env.reset())
        obs_dict = {acting_agent_id: observation}
        if self.adversary_agent_id is not None:
            # Adversary simultaneously makes same sequential observations as standard agent
            obs_dict[self.adversary_agent_id] = observation
        elif self._adversary_provides_distribution:
            # domain randomization: random Beta parameters for this whole episode
            self._adversary_distribution_params = self._action_to_beta_params(
                self.action_space["adversary"].sample())
        assert observation is not None

        return obs_dict

    def step(self, action_dict):
        """Dispatch to the step implementation for the configured adversary mode."""
        if self._adversary_provides_distribution:
            return self._step_with_adversary_providing_distribution(action_dict=action_dict)
        return self._step_with_simultaneous_sequential_adversary(action_dict=action_dict)

    @staticmethod
    def _denormalize_action(action: np.ndarray, original_space: Box) -> np.ndarray:
        """Linearly map an action from [-1, 1] onto the [low, high] bounds of ``original_space``."""
        assert action in Box(low=-1.0, high=1.0, shape=action.shape), action
        assert np.shape(original_space.low)[0] > 1, original_space.low
        assert np.shape(original_space.high)[0] > 1, original_space.high
        action = (action + 1.0) / 2.0
        # An action of exactly -1 (valid, e.g. after clipping) maps to 0, so the lower bound is inclusive.
        assert (action >= 0.0).all() and (action <= 1.0).all(), action
        space_range = np.asarray(original_space.high) - np.asarray(original_space.low)
        action = (action * space_range) + original_space.low
        assert action in original_space, action
        return action

    def _step_with_simultaneous_sequential_adversary(self, action_dict):
        """Step in simultaneous mode: the acting standard agent and the adversary act on the same base step.

        The returns are dicts mapping from agent ids to values. The number of agents in the env can vary
        over time: when a standard agent's base episode ends, the base env is reset and the next standard
        agent starts.

        Returns
        -------
            obs (dict): New observations for each ready agent.
            rewards (dict): Reward values for each ready agent. For a standard agent that has just
                started mid-episode, the value is None.
            dones (dict): Done values for each ready agent. The special key
                "__all__" (required) is used to indicate env termination.
            infos (dict): Optional info values for each agent id.
        """
        acting_agent_id = self.standard_agent_ids[self._acting_agent_idx]
        if acting_agent_id not in action_dict:
            raise ValueError(f"Acting agent id: {acting_agent_id} wasn't found in the action dict.")

        if self.adversary_agent_id is not None and self.adversary_agent_id not in action_dict:
            raise ValueError(f"Adversary agent id: {self.adversary_agent_id} wasn't found in the action dict.")

        pro_action = action_dict[acting_agent_id]
        if pro_action not in self.agent_ids_to_action_spaces[acting_agent_id]:
            raise ValueError(f"Action {pro_action} for agent id {acting_agent_id} wasn't a part of it's "
                             f"assigned action space: {self.agent_ids_to_action_spaces[acting_agent_id]}")

        adv_action_space = self.agent_ids_to_action_spaces[self.adversary_agent_id]
        if self.adversary_agent_id is not None:
            adv_action = action_dict[self.adversary_agent_id]
        else:
            # Domain randomization: no adversary policy, so draw a random force every step.
            adv_action = adv_action_space.sample()
        if adv_action not in adv_action_space:
            raise ValueError(f"Action {adv_action} for adversary agent id {self.adversary_agent_id} wasn't a part of it's "
                             f"assigned action space: {self.agent_ids_to_action_spaces[self.adversary_agent_id]}")

        pro_action = self._denormalize_action(action=pro_action, original_space=self._env.pro_action_space)
        adv_action = self._denormalize_action(action=adv_action, original_space=self._env.adv_action_space)

        obs = {}
        rews = {}
        dones = {"__all__": False}
        infos = {}

        observation, rew, done, _ = self._env.step(_JointAction(pro=pro_action, adv=adv_action))

        observation = self._post_process_obs(observation=observation)

        obs[acting_agent_id] = observation
        rews[acting_agent_id] = rew  # base env reward
        dones[acting_agent_id] = done
        infos[acting_agent_id] = {}

        if self.adversary_agent_id is not None:
            obs[self.adversary_agent_id] = observation
            rews[self.adversary_agent_id] = -rew  # zero-sum with the acting standard agent
            dones[self.adversary_agent_id] = False
            infos[self.adversary_agent_id] = {}

        if done:
            if self._acting_agent_idx + 1 >= len(self.standard_agent_ids):
                # All agents have finished. Episode is over.
                dones = {k: True for k in dones}

                infos[acting_agent_id] = {"full_episode_completed": True}
            else:
                self._acting_agent_idx += 1
                new_acting_agent_id = self.standard_agent_ids[self._acting_agent_idx]

                new_acting_agent_observation = self._env.reset()  # enter next agent's phase
                new_acting_agent_observation = self._post_process_obs(observation=new_acting_agent_observation)

                obs[new_acting_agent_id] = new_acting_agent_observation
                rews[new_acting_agent_id] = None  # new agent just started mid-episode
                dones[new_acting_agent_id] = False
                infos[new_acting_agent_id] = {}

                if self.adversary_agent_id is not None:
                    obs[self.adversary_agent_id] = new_acting_agent_observation
                    # adversary rew is unmodified from previous standard agent
                    dones[self.adversary_agent_id] = False
                    infos[self.adversary_agent_id] = {}

        return obs, rews, dones, infos

    def _sample_denormalized_adversary_action_given_beta_dist_params(self, original_space: Box):
        """Sample an adversary force from the current Beta parameters and map it from [0, 1] onto original_space.

        In x-axis-only mode a single (alpha, beta) pair (or the configured override) is used for both
        components, and the second (y) component of the force is then zeroed.
        """
        if self._x_axis_only:
            # We're going to ignore the 2nd beta distribution

            if self._x_axis_alpha_override is not None:
                alpha_params = [self._x_axis_alpha_override, self._x_axis_alpha_override]
                beta_params = [self._x_axis_beta_override, self._x_axis_beta_override]
            else:
                assert len(self._adversary_distribution_params) == 2, self._adversary_distribution_params
                alpha_params = [self._adversary_distribution_params[0], self._adversary_distribution_params[0]]
                beta_params = [self._adversary_distribution_params[1], self._adversary_distribution_params[1]]
        else:
            # First half of the parameters are the alphas, second half the betas, one pair per force dimension.
            alpha_params, beta_params = np.split(self._adversary_distribution_params, 2)

        adv_action = np.random.beta(a=alpha_params, b=beta_params)
        assert (adv_action >= 0.0).all() and (adv_action <= 1.0).all(), adv_action

        space_range = np.asarray(original_space.high) - np.asarray(original_space.low)
        adv_action = (adv_action * space_range) + original_space.low
        if self._x_axis_only:
            # nullify y component of force
            adv_action[1] = 0.0
        assert adv_action in original_space, adv_action
        return adv_action

    def set_x_axis_beta_dist_overrides(self, alpha, beta):
        """Override the x-axis Beta(alpha, beta) distribution; pass (None, None) to clear the override.

        The override is only consulted when ``x_axis_only`` is enabled.

        Raises:
            ValueError: if exactly one of ``alpha`` / ``beta`` is None.
        """
        self._check_x_axis_overrides(alpha=alpha, beta=beta)
        self._x_axis_alpha_override = alpha
        self._x_axis_beta_override = beta

    def _beta_stats_in_force_units(self, alpha, beta):
        """Return (mean, var, min, max) of Beta(alpha, beta) rescaled to [-adv_max_force, adv_max_force]."""
        dist_kwargs = dict(a=alpha, b=beta, loc=-self._env_max_force, scale=self._env_max_force * 2.0)
        # A confidence level of 1 yields the whole support, so min/max are always -/+ adv_max_force.
        support_min, support_max = scipy.stats.beta.interval(1, **dist_kwargs)
        return (scipy.stats.beta.mean(**dist_kwargs), scipy.stats.beta.var(**dist_kwargs),
                support_min, support_max)

    def _step_with_adversary_providing_distribution(self, action_dict):
        """Step in distribution mode, where exactly one agent acts at a time.

        The adversary's single action sets the Beta parameters for the whole episode. Each standard agent
        then plays a full base episode against forces sampled from them. When the last agent finishes, the
        adversary receives the negated return of agent 0, and the final info carries summary statistics of
        the Beta distribution(s) in force units.
        """
        if len(action_dict) > 1:
            raise ValueError("Only one agent ever acts in the environment at a time. "
                             f"Action dict contained multiple actions: {action_dict}")

        acting_agent_id = self.all_agent_ids[self._acting_agent_idx]
        if acting_agent_id not in action_dict:
            raise ValueError(f"Acting agent id: {acting_agent_id} wasn't found in the action dict.")
        action = action_dict[acting_agent_id]
        if action not in self.agent_ids_to_action_spaces[acting_agent_id]:
            raise ValueError(f"Action for agent id {acting_agent_id} wasn't a part of it's "
                             f"assigned action space: {self.agent_ids_to_action_spaces[acting_agent_id]}")

        obs = {}
        rews = {}
        dones = {"__all__": False}
        infos = {}

        if acting_agent_id == self.adversary_agent_id:
            # Move is by the adversary: it only picks the Beta parameters, then hands over to the first
            # standard agent.
            self._adversary_distribution_params = self._action_to_beta_params(action)
            done = True
        else:
            # move is by a protagonist or antagonist
            pro_action = self._denormalize_action(action=action, original_space=self._env.pro_action_space)
            adv_action = self._sample_denormalized_adversary_action_given_beta_dist_params(
                original_space=self._env.adv_action_space)

            observation, rew, done, _ = self._env.step(_JointAction(pro=pro_action, adv=adv_action))

            obs[acting_agent_id] = self._post_process_obs(observation=observation)
            rews[acting_agent_id] = rew  # base env reward
            if acting_agent_id == 0:
                self._default_protag_reward_this_episode += rew
            dones[acting_agent_id] = done
            infos[acting_agent_id] = {}

        if done:
            if self._acting_agent_idx + 1 >= len(self.all_agent_ids):
                # All agents have finished. Episode is over.
                dones["__all__"] = True

                info = {"full_episode_completed": True}
                infos[acting_agent_id] = info

                if self.adversary_agent_id is not None:
                    obs[self.adversary_agent_id] = np.zeros(shape=self.observation_space["adversary"].shape, dtype=np.float32)
                    rews[self.adversary_agent_id] = -self._default_protag_reward_this_episode
                    dones[self.adversary_agent_id] = True
                    infos[self.adversary_agent_id] = {}

                if self._x_axis_only:
                    if self._x_axis_alpha_override is not None:
                        alpha = self._x_axis_alpha_override
                        beta = self._x_axis_beta_override
                    else:
                        assert len(self._adversary_distribution_params) == 2, self._adversary_distribution_params
                        alpha = self._adversary_distribution_params[0]
                        beta = self._adversary_distribution_params[1]
                    mean_0, var_0, min_0, max_0 = self._beta_stats_in_force_units(alpha, beta)
                    info["beta_0_mean"] = mean_0
                    info["beta_0_var"] = var_0
                    info["beta_0_max"] = max_0
                    info["beta_0_min"] = min_0
                else:
                    alpha_params, beta_params = np.split(self._adversary_distribution_params, 2)
                    mean_0, var_0, min_0, max_0 = self._beta_stats_in_force_units(alpha_params[0], beta_params[0])
                    mean_1, var_1, _, _ = self._beta_stats_in_force_units(alpha_params[1], beta_params[1])
                    info["beta_0_mean"] = mean_0
                    info["beta_0_var"] = var_0
                    info["beta_1_mean"] = mean_1
                    info["beta_1_var"] = var_1
                    info["beta_0_max"] = max_0
                    info["beta_0_min"] = min_0
            else:
                self._acting_agent_idx += 1
                new_acting_agent_id = self.all_agent_ids[self._acting_agent_idx]

                new_acting_agent_observation = self._env.reset()  # enter next agent's phase

                obs[new_acting_agent_id] = self._post_process_obs(observation=new_acting_agent_observation)
                rews[new_acting_agent_id] = None  # new agent just started mid-episode
                dones[new_acting_agent_id] = False
                infos[new_acting_agent_id] = {}

        return obs, rews, dones, infos

    def render(self, mode="human"):
        """Render the underlying base env."""
        return self._env.render(mode=mode)


